import numpy as np

class ActionRobustQ:
    """Tabular Q-learning agent whose bootstrap target mixes best and worst next values.

    The update target for a transition ``(s0, a0, r, s1)`` is::

        r + gamma * ((1 - rho) * max_a Q[s1, a] + rho * min_a Q[s1, a])

    so ``rho = 0`` reduces to the ordinary Q-learning target, while larger
    ``rho`` puts more weight on the worst next-state action value.
    NOTE(review): ``rho`` presumably models the probability that an adversary
    overrides the agent's action -- confirm against the caller / paper.
    """

    def __init__(self,
                 n_state,
                 n_action,
                 rho,
                 epsilon,
                 alpha,
                 gamma):
        """Create the agent with an all-zero Q-table.

        Args:
            n_state: Number of discrete states (rows of the Q-table).
            n_action: Number of discrete actions (columns of the Q-table).
            rho: Weight on the minimum next-state value in the update target.
            epsilon: Probability of a uniformly random action while training.
            alpha: Step size applied to the TD error.
            gamma: Discount factor.
        """
        self.Q_table = np.zeros([n_state, n_action])  # Q(s, a), initialised to zeros
        self.n_action = n_action
        self.rho = rho
        self.alpha = alpha
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # exploration rate for epsilon-greedy

    def take_action(self, state: int, h: int, is_train: bool) -> int:
        """Select an action for ``state``.

        Args:
            state: Index of the current state (row of the Q-table).
            h: Unused by this agent; presumably a step index kept so the
                signature matches other agents -- confirm against callers.
            is_train: If true, act epsilon-greedily; otherwise act greedily.

        Returns:
            The chosen action index as a builtin ``int``. Greedy ties resolve
            to the lowest action index (``np.argmax`` behaviour).
        """
        # Draw the exploration coin first, and only when training, so the
        # sequence of RNG calls is the same as it has always been.
        if is_train and np.random.random() < self.epsilon:
            return int(np.random.randint(self.n_action))
        # np.argmax yields a NumPy integer scalar; convert it so the declared
        # ``-> int`` return type actually holds for callers.
        return int(np.argmax(self.Q_table[state]))

    def _robust_value(self, state) -> float:
        """Return ``(1 - rho) * max_a Q[state, a] + rho * min_a Q[state, a]``."""
        q_row = self.Q_table[state]
        return (1 - self.rho) * q_row.max() + self.rho * q_row.min()

    def update(self, s0, a0, r, s1, h) -> None:
        """Apply one TD update to ``Q[s0, a0]`` in place.

        Args:
            s0: State index the action was taken in.
            a0: Action index that was taken.
            r: Reward received.
            s1: Next state index.
            h: Unused by this agent; presumably a step index kept so the
                signature matches other agents -- confirm against callers.
        """
        # NOTE(review): there is no terminal-state handling here; the target
        # always bootstraps from ``s1``.
        target = r + self.gamma * self._robust_value(s1)
        td_error = target - self.Q_table[s0, a0]
        self.Q_table[s0, a0] += self.alpha * td_error

    def update_qv(self):
        """Do nothing.

        Presumably kept so this agent exposes the same methods as other
        agents -- confirm against callers.
        """
        pass